{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型的存储与读取\n",
    "\n",
    "## 存储完整的模型结构与参数\n",
    "\n",
    "注意:`to_json()` 只序列化模型结构(各层的配置),不包含权重;权重需要另外用 `save_weights()` 保存。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a single-layer model and save its architecture as JSON\n",
    "import tensorflow.keras as keras\n",
    "\n",
    "model = keras.Sequential([keras.layers.Dense(1, input_shape=[5])])\n",
    "model_json = model.to_json()\n",
    "\n",
    "with open('file\\\\fullmodel.json', 'w') as json_file:\n",
    "    json_file.write(model_json)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'{\"class_name\": \"Sequential\", \"config\": {\"name\": \"sequential_5\", \"layers\": [{\"class_name\": \"Dense\", \"config\": {\"name\": \"dense_5\", \"trainable\": true, \"batch_input_shape\": [null, 5], \"dtype\": \"float32\", \"units\": 1, \"activation\": \"linear\", \"use_bias\": true, \"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"seed\": null}}, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"kernel_regularizer\": null, \"bias_regularizer\": null, \"activity_regularizer\": null, \"kernel_constraint\": null, \"bias_constraint\": null}}]}, \"keras_version\": \"2.2.4-tf\", \"backend\": \"tensorflow\"}'"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the saved JSON back and display it to inspect what was stored\n",
    "with open('file\\\\fullmodel.json') as json_file:\n",
    "    content = json_file.read()\n",
    "content"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 同时读取模型结构与参数\n",
    "\n",
    "注意:`model_from_json()` 只根据 JSON 重建模型结构,权重会被重新初始化;如需恢复已保存的权重,还要再调用 `load_weights()`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_5\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense_5 (Dense)              (None, 1)                 6         \n",
      "=================================================================\n",
      "Total params: 6\n",
      "Trainable params: 6\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "from tensorflow.keras.models import model_from_json\n",
    "\n",
    "# Rebuild the model from the JSON stored on disk rather than from the in-memory\n",
    "# `model_json` variable left over from the saving cell, so this cell really\n",
    "# exercises the saved file and does not depend on the saving cell having run.\n",
    "with open('file\\\\fullmodel.json', 'r') as f:\n",
    "    content = f.read()\n",
    "\n",
    "# model_from_json restores the architecture only; weights are freshly initialized.\n",
    "model = model_from_json(content)\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 只保存和读取参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the model architecture\n",
    "import tensorflow.keras as keras\n",
    "\n",
    "def create_model():\n",
    "    \"\"\"Return a new Sequential model: one Dense layer with 1 unit over 5 input features.\"\"\"\n",
    "    dense = keras.layers.Dense(1, input_shape=[5])\n",
    "    return keras.Sequential([dense])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a model and save only its weights to disk\n",
    "weights_path = 'file\\\\weights.h5'\n",
    "model_1 = create_model()\n",
    "model_1.save_weights(weights_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a second model with the same architecture and load the saved weights into it\n",
    "model_2 = create_model()\n",
    "saved_weights = 'file\\\\weights.h5'\n",
    "model_2.load_weights(saved_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
